{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ZILNLoss\n",
    "\n",
    "**[DISCLAIMER]**\n",
    "\n",
    "The purpose of this notebook is to check whether the ZILN loss, [originally implemented in Keras](https://github.com/google/lifetime_value/blob/master/notebooks/kdd_cup_98/regression.ipynb), gives the same results in the pytorch-widedeep implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "gE76T8J7IsGC"
   },
   "outputs": [],
   "source": [
    "# @title Copyright 2019 The Lifetime Value Authors.\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ============================================================================"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sswTFWDv7HZd"
   },
   "source": [
    "# KDD Cup 98 LTV Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PSr1mSJP7O1J"
   },
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/lifetime_value/blob/master/notebooks/kdd_cup_98/regression.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/google/lifetime_value/blob/master/notebooks/kdd_cup_98/regression.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "pBXE3Dz3NI4A"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "from typing import Sequence\n",
    "\n",
    "# install and import ltv\n",
    "!pip install -q git+https://github.com/google/lifetime_value\n",
    "import lifetime_value as ltv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Bq0Ah16lBmgV"
   },
   "outputs": [],
   "source": [
    "# Shorthand alias for the TensorFlow Probability distributions module.\n",
    "tfd = tfp.distributions\n",
    "# Notebook display setup: retina-resolution inline figures, seaborn \"whitegrid\" style.\n",
    "%config InlineBackend.figure_format='retina'\n",
    "sns.set_style(\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2qN319qZK3IG"
   },
   "source": [
    "## Configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hNy_ybw_K19n"
   },
   "outputs": [],
   "source": [
    "# Run configuration. MODEL, LOSS and VERSION label the metrics row and the output\n",
    "# file name; LOSS also selects the training loss and LEARNING_RATE is passed to\n",
    "# both optimizers. The trailing `# @param` comments are Colab form annotations.\n",
    "MODEL = \"dnn\"\n",
    "LOSS = \"ziln\"  # @param { isTemplate: true, type: 'string'} ['mse', 'ziln']\n",
    "LEARNING_RATE = 0.001  # @param { isTemplate: true}\n",
    "VERSION = 0  # @param { isTemplate: true, type: 'integer'}\n",
    "OUTPUT_CSV_FOLDER = \"/tmp/lifetime-value/kdd_cup_98/result\"  # @param { isTemplate: true, type: 'string'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mDSR921CCEcL"
   },
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lHxp4rOGI02Q"
   },
   "source": [
    "Download kdd_cup_98 data to /tmp/lifetime-value/kdd_cup_98"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Dg3qtgJyJpdi"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Download and unpack the KDD Cup 98 files into /tmp/lifetime-value/kdd_cup_98.\n",
    "# Safe to re-run: files that are already present are not downloaded again and\n",
    "# the archives are re-extracted in place without prompting.\n",
    "# Abort on the first failing command so a broken download is not silently ignored.\n",
    "set -euo pipefail\n",
    "\n",
    "DATA_DIR=/tmp/lifetime-value/kdd_cup_98\n",
    "BASE_URL=https://kdd.ics.uci.edu/databases/kddcup98/epsilon_mirror\n",
    "\n",
    "mkdir -p \"$DATA_DIR\"\n",
    "cd \"$DATA_DIR\"\n",
    "\n",
    "for name in cup98lrn.zip cup98val.zip valtargt.txt; do\n",
    "  if [ ! -f \"$name\" ]; then\n",
    "    # Fetch under a temporary name first so an interrupted download is never\n",
    "    # mistaken for a complete file on the next run.\n",
    "    wget -O \"$name.part\" \"$BASE_URL/$name\"\n",
    "    mv \"$name.part\" \"$name\"\n",
    "  fi\n",
    "done\n",
    "\n",
    "# -o: overwrite previously extracted files instead of prompting.\n",
    "unzip -o cup98lrn.zip\n",
    "unzip -o cup98val.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "a_LnLmQQRlYF"
   },
   "outputs": [],
   "source": [
    "# Learning set; its row count is used later to split the combined frame back\n",
    "# into train and eval rows.\n",
    "df_train = pd.read_csv(\"/tmp/lifetime-value/kdd_cup_98/cup98LRN.txt\")\n",
    "num_train = df_train.shape[0]\n",
    "# Validation set and its targets are read from separate files and joined on CONTROLN.\n",
    "df_eval = pd.read_csv(\"/tmp/lifetime-value/kdd_cup_98/cup98VAL.txt\")\n",
    "df_eval_target = pd.read_csv(\"/tmp/lifetime-value/kdd_cup_98/valtargt.txt\")\n",
    "df_eval = df_eval.merge(df_eval_target, on=\"CONTROLN\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ggQmy9wiP5M6"
   },
   "outputs": [],
   "source": [
    "# Stack train on top of eval so the feature preprocessing below is applied to\n",
    "# both at once; the first num_train rows are the training rows.\n",
    "df = pd.concat([df_train, df_eval], axis=0, sort=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0rgxHpIyjaMH"
   },
   "source": [
    "## Label distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xmpu_d3YjcFC"
   },
   "outputs": [],
   "source": [
    "# TARGET_D of the training rows (the first num_train rows of the combined frame).\n",
    "y = df[\"TARGET_D\"][:num_train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yMr2EDRyK5Sb"
   },
   "outputs": [],
   "source": [
    "def plot_hist_log_scale(\n",
    "    y, output_path=\"/tmp/lifetime-value/kdd_cup_98/histogram_kdd98_log_scale.pdf\"\n",
    "):\n",
    "    \"\"\"Plot a log-binned histogram of `y`, save it as a PDF and display it.\n",
    "\n",
    "    Args:\n",
    "      y: 1D array_like of values to plot; it must provide `.max()`. Bin edges\n",
    "        run from 1 to `y.max() + 1` on a log10 grid.\n",
    "      output_path: where the PDF copy of the figure is written.\n",
    "    \"\"\"\n",
    "    max_val = y.max() + 1.0\n",
    "    ax = pd.Series(y).hist(\n",
    "        figsize=(8, 5), bins=10 ** np.linspace(0.0, np.log10(max_val), 20)\n",
    "    )\n",
    "\n",
    "    plt.xlabel(\"Donation ($)\")\n",
    "    plt.ylabel(\"Count\")\n",
    "    plt.xticks(rotation=\"horizontal\")\n",
    "    # No legend: the histogram has no labelled artists, so requesting one only\n",
    "    # produced a matplotlib warning.\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.grid(False)\n",
    "    # Hide the right and top spines\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "    # Only show ticks on the left and bottom spines\n",
    "    ax.yaxis.set_ticks_position(\"left\")\n",
    "    ax.xaxis.set_ticks_position(\"bottom\")\n",
    "\n",
    "    # Save before showing, and through a context manager so the file handle is\n",
    "    # always flushed and closed.\n",
    "    fig = ax.get_figure()\n",
    "    with tf.io.gfile.GFile(output_path, \"wb\") as output_file:\n",
    "        fig.savefig(output_file, bbox_inches=\"tight\", format=\"pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KbwCzGkBOWhH"
   },
   "outputs": [],
   "source": [
    "# Plot only the strictly positive training targets.\n",
    "plot_hist_log_scale(y[y > 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1XXMLbnlCdlN"
   },
   "source": [
    "## Preprocess features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L1sBf_RSU3pR"
   },
   "source": [
    "### Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xB_ddsd_U_4e"
   },
   "outputs": [],
   "source": [
    "# Columns that are integer-encoded by value frequency (see label_encoding) and\n",
    "# used as categorical model inputs.\n",
    "VOCAB_FEATURES = [\n",
    "    \"ODATEDW\",  # date of donor's first gift (YYMM)\n",
    "    \"OSOURCE\",  # donor acquisition mailing list\n",
    "    \"TCODE\",  # donor title code\n",
    "    \"STATE\",\n",
    "    \"ZIP\",\n",
    "    \"DOMAIN\",  # urbanicity level and socio-economic status of the neighborhood\n",
    "    \"CLUSTER\",  # socio-economic status\n",
    "    \"GENDER\",\n",
    "    \"MAXADATE\",  # date of the most recent promotion received\n",
    "    \"MINRDATE\",\n",
    "    \"LASTDATE\",\n",
    "    \"FISTDATE\",\n",
    "    \"RFA_2A\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f2oPZGVLRSPe"
   },
   "outputs": [],
   "source": [
    "# Normalise the raw vocabulary columns to strings before they are encoded.\n",
    "df[\"ODATEDW\"] = df[\"ODATEDW\"].astype(\"str\")\n",
    "# TCODE values above 1000 are reduced with x // 1000, then zero-padded to 3 digits.\n",
    "df[\"TCODE\"] = df[\"TCODE\"].apply(lambda x: \"{:03d}\".format(x // 1000 if x > 1000 else x))\n",
    "# Keep only the first five characters of ZIP.\n",
    "df[\"ZIP\"] = df[\"ZIP\"].str.slice(0, 5)\n",
    "df[\"MAXADATE\"] = df[\"MAXADATE\"].astype(\"str\")\n",
    "df[\"MINRDATE\"] = df[\"MINRDATE\"].astype(\"str\")\n",
    "df[\"LASTDATE\"] = df[\"LASTDATE\"].astype(\"str\")\n",
    "df[\"FISTDATE\"] = df[\"FISTDATE\"].astype(\"str\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "isL9Ofv9JLAP"
   },
   "outputs": [],
   "source": [
    "def label_encoding(y, frequency_threshold=100):\n",
    "    \"\"\"Encode values as integer category codes, pooling rare values into 0.\n",
    "\n",
    "    Args:\n",
    "      y: 1D array_like (e.g. a pandas Series) of hashable values.\n",
    "      frequency_threshold: minimum number of occurrences a value needs in\n",
    "        order to receive its own code.\n",
    "\n",
    "    Returns:\n",
    "      Integer numpy array aligned with `y`: codes 1..K for the K frequent\n",
    "      values (most frequent first) and 0 for rare or missing values.\n",
    "    \"\"\"\n",
    "    # The top-level `pd.value_counts` is deprecated; use the Series method.\n",
    "    value_counts = pd.Series(y).value_counts()\n",
    "    categories = value_counts[value_counts >= frequency_threshold].index.to_numpy()\n",
    "    # Values outside `categories` get code -1 from pd.Categorical; the +1 shift\n",
    "    # makes 0 the unknown category.\n",
    "    return pd.Categorical(y, categories=categories).codes + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BgXGO5D0OdJP"
   },
   "outputs": [],
   "source": [
    "# Replace every vocabulary column by its integer codes (default frequency threshold).\n",
    "for key in VOCAB_FEATURES:\n",
    "    df[key] = label_encoding(df[key])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kZkmnJ93Zrjw"
   },
   "source": [
    "### Indicator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tGBpMfaGhCD0"
   },
   "outputs": [],
   "source": [
    "# Columns that are quantile-bucketed (q=20) further below.\n",
    "MAIL_ORDER_RESPONSES = [\n",
    "    \"MBCRAFT\",\n",
    "    \"MBGARDEN\",\n",
    "    \"MBBOOKS\",\n",
    "    \"MBCOLECT\",\n",
    "    \"MAGFAML\",\n",
    "    \"MAGFEM\",\n",
    "    \"MAGMALE\",\n",
    "    \"PUBGARDN\",\n",
    "    \"PUBCULIN\",\n",
    "    \"PUBHLTH\",\n",
    "    \"PUBDOITY\",\n",
    "    \"PUBNEWFN\",\n",
    "    \"PUBPHOTO\",\n",
    "    \"PUBOPP\",\n",
    "    \"RFA_2F\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4V-DeOZFZhjB"
   },
   "outputs": [],
   "source": [
    "# Small-integer columns that join the vocabulary columns as categorical inputs.\n",
    "INDICATOR_FEATURES = [\n",
    "    \"AGE\",  # age decile, 0 indicates unknown\n",
    "    \"NUMCHLD\",\n",
    "    \"INCOME\",\n",
    "    \"WEALTH1\",\n",
    "    \"HIT\",\n",
    "] + MAIL_ORDER_RESPONSES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U9y5qA1vZ0kz"
   },
   "outputs": [],
   "source": [
    "# Decile-bucket AGE; the +1 shift moves the code for missing values (-1) to 0.\n",
    "df[\"AGE\"] = pd.qcut(df[\"AGE\"].values, 10).codes + 1\n",
    "# Missing NUMCHLD / INCOME become 0; other values are cast to int.\n",
    "df[\"NUMCHLD\"] = df[\"NUMCHLD\"].apply(lambda x: 0 if np.isnan(x) else int(x))\n",
    "df[\"INCOME\"] = df[\"INCOME\"].apply(lambda x: 0 if np.isnan(x) else int(x))\n",
    "# WEALTH1 is shifted by one so that 0 is reserved for missing values.\n",
    "df[\"WEALTH1\"] = df[\"WEALTH1\"].apply(lambda x: 0 if np.isnan(x) else int(x) + 1)\n",
    "# Up to 50 quantile buckets (duplicate bin edges are dropped).\n",
    "# NOTE(review): unlike the other bucketed columns no +1 shift is applied here;\n",
    "# presumably HIT has no missing values - confirm.\n",
    "df[\"HIT\"] = pd.qcut(df[\"HIT\"].values, q=50, duplicates=\"drop\").codes\n",
    "\n",
    "# Up to 20 quantile buckets per column, shifted so that missing values map to 0.\n",
    "for col in MAIL_ORDER_RESPONSES:\n",
    "    df[col] = pd.qcut(df[col].values, q=20, duplicates=\"drop\").codes + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8DOO_2a-U6gr"
   },
   "source": [
    "### Numeric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rqVteSLDiLVr"
   },
   "outputs": [],
   "source": [
    "# Columns fed to the models as one float32 matrix. The order matters: later\n",
    "# cells address columns of that matrix by position.\n",
    "NUMERIC_FEATURES = [\n",
    "    # binary\n",
    "    \"MAILCODE\",  # bad address\n",
    "    \"NOEXCH\",  # do not exchange\n",
    "    \"RECINHSE\",  # donor has given to PVA's in house program\n",
    "    \"RECP3\",  # donor has given to PVA's P3 program\n",
    "    \"RECPGVG\",  # planned giving record\n",
    "    \"RECSWEEP\",  # sweepstakes record\n",
    "    \"HOMEOWNR\",  # home owner\n",
    "    \"CHILD03\",\n",
    "    \"CHILD07\",\n",
    "    \"CHILD12\",\n",
    "    \"CHILD18\",\n",
    "    # continuous\n",
    "    \"CARDPROM\",\n",
    "    \"NUMPROM\",\n",
    "    \"CARDPM12\",\n",
    "    \"NUMPRM12\",\n",
    "    \"RAMNTALL\",\n",
    "    \"NGIFTALL\",\n",
    "    \"MINRAMNT\",\n",
    "    \"MAXRAMNT\",\n",
    "    \"LASTGIFT\",\n",
    "    \"AVGGIFT\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xMRP05Ztic0A"
   },
   "outputs": [],
   "source": [
    "# Binary flags: 1.0 when the raw column holds one of the listed marker values,\n",
    "# 0.0 otherwise.\n",
    "df[\"MAILCODE\"] = (df[\"MAILCODE\"] == \"B\").astype(\"float32\")\n",
    "# NOTE(review): PVASTATE is converted here but is not listed in NUMERIC_FEATURES,\n",
    "# so it is not part of the model inputs - confirm whether that is intended.\n",
    "df[\"PVASTATE\"] = df[\"PVASTATE\"].isin([\"P\", \"E\"]).astype(\"float32\")\n",
    "df[\"NOEXCH\"] = df[\"NOEXCH\"].isin([\"X\", \"1\"]).astype(\"float32\")\n",
    "df[\"RECINHSE\"] = (df[\"RECINHSE\"] == \"X\").astype(\"float32\")\n",
    "df[\"RECP3\"] = (df[\"RECP3\"] == \"X\").astype(\"float32\")\n",
    "df[\"RECPGVG\"] = (df[\"RECPGVG\"] == \"X\").astype(\"float32\")\n",
    "df[\"RECSWEEP\"] = (df[\"RECSWEEP\"] == \"X\").astype(\"float32\")\n",
    "df[\"HOMEOWNR\"] = (df[\"HOMEOWNR\"] == \"H\").astype(\"float32\")\n",
    "df[\"CHILD03\"] = df[\"CHILD03\"].isin([\"M\", \"F\", \"B\"]).astype(\"float32\")\n",
    "df[\"CHILD07\"] = df[\"CHILD07\"].isin([\"M\", \"F\", \"B\"]).astype(\"float32\")\n",
    "df[\"CHILD12\"] = df[\"CHILD12\"].isin([\"M\", \"F\", \"B\"]).astype(\"float32\")\n",
    "df[\"CHILD18\"] = df[\"CHILD18\"].isin([\"M\", \"F\", \"B\"]).astype(\"float32\")\n",
    "\n",
    "# Promotion counts are divided by 100; amounts and gift counts are log1p-transformed.\n",
    "df[\"CARDPROM\"] = df[\"CARDPROM\"] / 100\n",
    "df[\"NUMPROM\"] = df[\"NUMPROM\"] / 100\n",
    "df[\"CARDPM12\"] = df[\"CARDPM12\"] / 100\n",
    "df[\"NUMPRM12\"] = df[\"NUMPRM12\"] / 100\n",
    "df[\"RAMNTALL\"] = np.log1p(df[\"RAMNTALL\"])\n",
    "df[\"NGIFTALL\"] = np.log1p(df[\"NGIFTALL\"])\n",
    "df[\"MINRAMNT\"] = np.log1p(df[\"MINRAMNT\"])\n",
    "df[\"MAXRAMNT\"] = np.log1p(df[\"MAXRAMNT\"])\n",
    "df[\"LASTGIFT\"] = np.log1p(df[\"LASTGIFT\"])\n",
    "df[\"AVGGIFT\"] = np.log1p(df[\"AVGGIFT\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GoLg1PvWuCT_"
   },
   "source": [
    "### All"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lSnNgjBCuJdb"
   },
   "outputs": [],
   "source": [
    "# Categorical inputs are the vocabulary columns followed by the indicator columns.\n",
    "CATEGORICAL_FEATURES = VOCAB_FEATURES + INDICATOR_FEATURES\n",
    "ALL_FEATURES = CATEGORICAL_FEATURES + NUMERIC_FEATURES"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8HJBvvCxRPg3"
   },
   "source": [
    "## Train/eval split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N7BXLB1eHovl"
   },
   "outputs": [],
   "source": [
    "def dnn_split(df):\n",
    "    \"\"\"Split the preprocessed frame into train/eval model inputs and labels.\n",
    "\n",
    "    Rows before `num_train` form the training set, the remaining rows the\n",
    "    evaluation set.\n",
    "\n",
    "    Args:\n",
    "      df: preprocessed DataFrame holding the feature columns and `TARGET_D`.\n",
    "\n",
    "    Returns:\n",
    "      (x_train, x_eval, y_train, y_eval). Each x is a dict mapping every\n",
    "      categorical feature name to its value array, plus a \"numeric\" entry with\n",
    "      the float32 matrix of numeric features; each y is a float32 array.\n",
    "    \"\"\"\n",
    "\n",
    "    def to_inputs(frame):\n",
    "        inputs = {name: frame[name].values for name in CATEGORICAL_FEATURES}\n",
    "        inputs[\"numeric\"] = frame[NUMERIC_FEATURES].astype(\"float32\").values\n",
    "        return inputs\n",
    "\n",
    "    def to_labels(frame):\n",
    "        return frame[\"TARGET_D\"].astype(\"float32\").values\n",
    "\n",
    "    train_part = df.iloc[:num_train]\n",
    "    eval_part = df.iloc[num_train:]\n",
    "    return (\n",
    "        to_inputs(train_part),\n",
    "        to_inputs(eval_part),\n",
    "        to_labels(train_part),\n",
    "        to_labels(eval_part),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4yw6fekBtX7X"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_rIuO0XYtZH2"
   },
   "outputs": [],
   "source": [
    "def embedding_dim(x):\n",
    "    \"\"\"Embedding width for a vocabulary of size `x`: int(x ** 0.25) + 1.\"\"\"\n",
    "    fourth_root = x**0.25\n",
    "    return int(fourth_root) + 1\n",
    "\n",
    "\n",
    "def embedding_layer(vocab_size):\n",
    "    \"\"\"Embedding followed by Flatten for one length-1 categorical input.\n",
    "\n",
    "    Args:\n",
    "      vocab_size: passed to the Embedding layer as `input_dim`; the output\n",
    "        width is `embedding_dim(vocab_size)`.\n",
    "\n",
    "    Returns:\n",
    "      A `tf.keras.Sequential` wrapping the two layers.\n",
    "    \"\"\"\n",
    "    lookup = tf.keras.layers.Embedding(\n",
    "        input_dim=vocab_size,\n",
    "        output_dim=embedding_dim(vocab_size),\n",
    "        input_length=1,\n",
    "    )\n",
    "    return tf.keras.Sequential([lookup, tf.keras.layers.Flatten()])\n",
    "\n",
    "\n",
    "def dnn_model(output_units):\n",
    "    \"\"\"Feed-forward Keras model over numeric features and embedded categoricals.\n",
    "\n",
    "    Args:\n",
    "      output_units: width of the final Dense layer, which has no activation.\n",
    "\n",
    "    Returns:\n",
    "      A `tf.keras.Model` whose inputs are the \"numeric\" matrix followed by one\n",
    "      int64 input per entry of CATEGORICAL_FEATURES, named after the feature.\n",
    "    \"\"\"\n",
    "    layers = tf.keras.layers\n",
    "\n",
    "    numeric_input = layers.Input(shape=(len(NUMERIC_FEATURES),), name=\"numeric\")\n",
    "    categorical_inputs = [\n",
    "        layers.Input(shape=(1,), name=feature, dtype=np.int64)\n",
    "        for feature in CATEGORICAL_FEATURES\n",
    "    ]\n",
    "\n",
    "    embedded = []\n",
    "    for feature, categorical_input in zip(CATEGORICAL_FEATURES, categorical_inputs):\n",
    "        # The vocabulary size is the largest code in the column plus one.\n",
    "        lookup = embedding_layer(vocab_size=df[feature].max() + 1)\n",
    "        embedded.append(lookup(categorical_input))\n",
    "\n",
    "    features = layers.concatenate([numeric_input] + embedded)\n",
    "    hidden = [layers.Dense(width, activation=\"relu\") for width in (128, 128, 64, 64)]\n",
    "    tower = tf.keras.Sequential(hidden + [layers.Dense(units=output_units)])\n",
    "    return tf.keras.Model(\n",
    "        inputs=[numeric_input] + categorical_inputs, outputs=tower(features)\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G5h7X6botcHl"
   },
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iJ9gpkC6tgP0"
   },
   "outputs": [],
   "source": [
    "# Pick the training loss and the matching number of network outputs.\n",
    "if LOSS == \"mse\":\n",
    "    loss = tf.keras.losses.MeanSquaredError()\n",
    "    output_units = 1\n",
    "elif LOSS == \"ziln\":\n",
    "    loss = ltv.zero_inflated_lognormal_loss\n",
    "    output_units = 3\n",
    "else:\n",
    "    # Fail here instead of with a NameError on `loss` in a later cell.\n",
    "    raise ValueError(\"Unsupported LOSS {!r}; expected 'mse' or 'ziln'.\".format(LOSS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_afFfIritjCM"
   },
   "outputs": [],
   "source": [
    "# Build the train/eval inputs and a Keras model with the output width chosen above.\n",
    "x_train, x_eval, y_train, y_eval = dnn_split(df)\n",
    "model = dnn_model(output_units)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qj3kI7pyVwzO"
   },
   "outputs": [],
   "source": [
    "# `learning_rate` is the supported argument name; the old `lr` alias is deprecated.\n",
    "model.compile(\n",
    "    optimizer=tf.keras.optimizers.Nadam(learning_rate=LEARNING_RATE), loss=loss\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KZSYxgWdwiXC"
   },
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Nwj9h5ysQDLp"
   },
   "outputs": [],
   "source": [
    "# Both callbacks watch the validation loss: one lowers the learning rate on a\n",
    "# plateau (never below 1e-6), the other stops training after 10 epochs\n",
    "# without improvement.\n",
    "callbacks = [\n",
    "    tf.keras.callbacks.ReduceLROnPlateau(monitor=\"val_loss\", min_lr=1e-6),\n",
    "    tf.keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=10),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vb5Tnld6hsfx"
   },
   "outputs": [],
   "source": [
    "# Train with the eval split as validation data; only the per-epoch metrics\n",
    "# dict (`.history`) is kept.\n",
    "history = model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    batch_size=2048,\n",
    "    epochs=200,\n",
    "    verbose=2,\n",
    "    callbacks=callbacks,\n",
    "    validation_data=(x_eval, y_eval),\n",
    ").history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J1sLSUdgvfa6"
   },
   "outputs": [],
   "source": [
    "# Training vs. validation loss per epoch.\n",
    "pd.DataFrame(history)[[\"loss\", \"val_loss\"]].plot();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jRKuZBqhvhT9"
   },
   "source": [
    "## Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "q9_zNMd3vjNk"
   },
   "outputs": [],
   "source": [
    "# Turn the network output on the eval split into a flat array of predictions.\n",
    "if LOSS == \"mse\":\n",
    "    y_pred = model.predict(x=x_eval, batch_size=1024).flatten()\n",
    "elif LOSS == \"ziln\":\n",
    "    # The raw outputs are converted to value predictions by the library helper.\n",
    "    logits = model.predict(x=x_eval, batch_size=1024)\n",
    "    y_pred = ltv.zero_inflated_lognormal_pred(logits).numpy().flatten()\n",
    "else:\n",
    "    # Fail here instead of with a NameError on `y_pred` in a later cell.\n",
    "    raise ValueError(\"Unsupported LOSS {!r}; expected 'mse' or 'ziln'.\".format(LOSS))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pytorch-widedeep approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.preprocessing import TabPreprocessor\n",
    "from pytorch_widedeep.training import Trainer\n",
    "from pytorch_widedeep.models import TabMlp, WideDeep\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from pytorch_widedeep.callbacks import EarlyStopping\n",
    "from torch.optim import NAdam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild DataFrames from the Keras input dicts: one \"num<i>\" column per numeric\n",
    "# feature (in NUMERIC_FEATURES order) and one column per categorical feature.\n",
    "# The column count follows NUMERIC_FEATURES instead of a hard-coded 21.\n",
    "NUMERICAL_FEATURES = [\"num\" + str(i) for i in range(len(NUMERIC_FEATURES))]\n",
    "x_train_pyt_num = pd.DataFrame(columns=NUMERICAL_FEATURES, data=x_train[\"numeric\"])\n",
    "x_train_pyt_cat = pd.DataFrame(\n",
    "    {key: value for key, value in x_train.items() if key != \"numeric\"}\n",
    ")\n",
    "\n",
    "x_eval_pyt_num = pd.DataFrame(columns=NUMERICAL_FEATURES, data=x_eval[\"numeric\"])\n",
    "x_eval_pyt_cat = pd.DataFrame(\n",
    "    {key: value for key, value in x_eval.items() if key != \"numeric\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric and categorical parts side by side, one frame per split.\n",
    "x_train_pyt = pd.concat([x_train_pyt_num, x_train_pyt_cat], axis=1)\n",
    "x_eval_pyt = pd.concat([x_eval_pyt_num, x_eval_pyt_cat], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (column, embedding width) pair per categorical feature, using the same\n",
    "# width rule as the Keras model.\n",
    "# NOTE(review): the Keras model derives the width from df[key].max() + 1 over\n",
    "# train and eval rows, whereas this uses the number of distinct training\n",
    "# values; the two can differ - confirm this is acceptable for the comparison.\n",
    "embed_input = [\n",
    "    (feature, embedding_dim(x_train_pyt[feature].nunique()))\n",
    "    for feature in CATEGORICAL_FEATURES\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# deeptabular\n",
    "tab_preprocessor = TabPreprocessor(\n",
    "    embed_cols=embed_input,\n",
    "    continuous_cols=NUMERICAL_FEATURES,\n",
    "    shared_embed=False,\n",
    "    scale=False,\n",
    ")\n",
    "X_tab_train = tab_preprocessor.fit_transform(x_train_pyt)\n",
    "X_tab_valid = tab_preprocessor.transform(x_eval_pyt)\n",
    "# The eval split serves as both validation and test data.\n",
    "X_tab_test = tab_preprocessor.transform(x_eval_pyt)\n",
    "\n",
    "# target\n",
    "y_valid = y_eval\n",
    "# X_tab_test is built from the eval split, so its labels are y_eval (not y_train).\n",
    "y_test = y_eval\n",
    "\n",
    "X_train = {\"X_tab\": X_tab_train, \"target\": y_train}\n",
    "X_val = {\"X_tab\": X_tab_valid, \"target\": y_valid}\n",
    "X_test = {\"X_tab\": X_tab_test}\n",
    "\n",
    "# Same hidden layer widths as the Keras model above.\n",
    "deeptabular = TabMlp(\n",
    "    mlp_hidden_dims=[128, 128, 64, 64],\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    embed_input=tab_preprocessor.cat_embed_input,\n",
    "    continuous_cols=tab_preprocessor.continuous_cols,\n",
    ")\n",
    "\n",
    "# Separate names (`model_pyt`, `callbacks_pyt`) keep the Keras `model` and\n",
    "# `callbacks` intact, so the Keras cells can still be re-run afterwards.\n",
    "model_pyt = WideDeep(deeptabular=deeptabular, pred_dim=3)\n",
    "\n",
    "deep_opt = NAdam(model_pyt.deeptabular.parameters(), lr=LEARNING_RATE)\n",
    "# NOTE(review): relies on the EarlyStopping defaults; presumably they match the\n",
    "# Keras callback (val_loss, patience=10) - confirm.\n",
    "callbacks_pyt = [EarlyStopping()]\n",
    "deep_sch = ReduceLROnPlateau(deep_opt, min_lr=1e-6)\n",
    "\n",
    "objective = \"ziln\"\n",
    "\n",
    "trainer = Trainer(\n",
    "    model_pyt,\n",
    "    callbacks=callbacks_pyt,\n",
    "    lr_schedulers={\"deeptabular\": deep_sch},\n",
    "    objective=objective,\n",
    "    optimizers={\"deeptabular\": deep_opt},\n",
    ")\n",
    "\n",
    "trainer.fit(\n",
    "    X_train=X_train,\n",
    "    X_val=X_val,\n",
    "    n_epochs=200,\n",
    "    batch_size=2048,\n",
    ")\n",
    "\n",
    "y_pred_pytorch = trainer.predict(X_test=X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training vs. validation loss per epoch for the pytorch-widedeep run.\n",
    "pd.DataFrame(trainer.history)[[\"train_loss\", \"val_loss\"]].plot();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# Mean squared difference between the Keras and pytorch-widedeep predictions on\n",
    "# the eval split. The pytorch run always uses the \"ziln\" objective, so the\n",
    "# comparison is only like-for-like when LOSS == \"ziln\".\n",
    "# NOTE(review): assumes y_pred_pytorch has the same shape as y_pred - confirm.\n",
    "mean_squared_error(y_pred, y_pred_pytorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SkfkUMUvUu_E"
   },
   "source": [
    "### Total Profit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AwfWAp8WQuns"
   },
   "outputs": [],
   "source": [
    "# Candidate unit costs; each one is used below both as the prediction threshold\n",
    "# and as the per-customer cost.\n",
    "unit_costs = [0.4, 0.5, 0.6, 0.68, 0.7, 0.8, 0.9, 1.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zqi91dfCUxpx"
   },
   "outputs": [],
   "source": [
    "# Number of eval customers whose prediction exceeds each unit cost.\n",
    "num_mailed = [np.sum(y_pred > v) for v in unit_costs]\n",
    "num_mailed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZgFjZUcuhScv"
   },
   "outputs": [],
   "source": [
    "# Baseline: every eval customer is included, at a unit cost of 0.68.\n",
    "baseline_total_profit = np.sum(y_eval - 0.68)\n",
    "baseline_total_profit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VwsFnin5U-R9"
   },
   "outputs": [],
   "source": [
    "# Label minus cost, summed over the customers whose prediction exceeds the cost.\n",
    "total_profits = [np.sum(y_eval[y_pred > v] - v) for v in unit_costs]\n",
    "total_profits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zROhsEWxnA5u"
   },
   "source": [
    "### Gini Coefficient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gRsJ7y-632h_"
   },
   "outputs": [],
   "source": [
    "# Cumulative-gain curves for three rankings of the eval customers: the labels\n",
    "# themselves, a single-feature baseline and the model predictions.\n",
    "# The baseline column is looked up by name (LASTGIFT) rather than by the\n",
    "# hard-coded position 19, so it survives edits to NUMERIC_FEATURES.\n",
    "gain = pd.DataFrame(\n",
    "    {\n",
    "        \"lorenz\": ltv.cumulative_true(y_eval, y_eval),\n",
    "        \"baseline\": ltv.cumulative_true(\n",
    "            y_eval, x_eval[\"numeric\"][:, NUMERIC_FEATURES.index(\"LASTGIFT\")]\n",
    "        ),\n",
    "        \"model\": ltv.cumulative_true(y_eval, y_pred),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yg-ndbve4AL_"
   },
   "outputs": [],
   "source": [
    "# x-axis of the gain chart: fraction of customers covered at each row, from 1/N to 1.\n",
    "num_customers = np.float32(gain.shape[0])\n",
    "gain[\"cumulative_customer\"] = (np.arange(num_customers) + 1.0) / num_customers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WEoAvuCj4OVy"
   },
   "outputs": [],
   "source": [
    "# The legend labels below follow the column order selected here:\n",
    "# lorenz, baseline, model.\n",
    "ax = gain[\n",
    "    [\n",
    "        \"cumulative_customer\",\n",
    "        \"lorenz\",\n",
    "        \"baseline\",\n",
    "        \"model\",\n",
    "    ]\n",
    "].plot(x=\"cumulative_customer\", figsize=(8, 5), legend=True)\n",
    "\n",
    "ax.legend([\"Groundtruth\", \"Baseline\", \"Model\"], loc=\"lower right\")\n",
    "\n",
    "ax.set_xlabel(\"Cumulative Fraction of Customers\")\n",
    "ax.set_xticks(np.arange(0, 1.1, 0.1))\n",
    "ax.set_xlim((0, 1.0))\n",
    "\n",
    "ax.set_ylabel(\"Cumulative Fraction of Total Lifetime Value\")\n",
    "ax.set_yticks(np.arange(0, 1.1, 0.1))\n",
    "ax.set_ylim((0, 1.05))\n",
    "ax.set_title(\"Gain Chart\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kzPqaiNO4iWC"
   },
   "outputs": [],
   "source": [
    "# Gini coefficients computed from the three gain curves.\n",
    "gini = ltv.gini_from_gain(gain[[\"lorenz\", \"baseline\", \"model\"]])\n",
    "gini"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S84RitIa9PBu"
   },
   "source": [
    "### Calibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "X7sKbsEf6RvF"
   },
   "outputs": [],
   "source": [
    "# Per-bucket statistics of the labels and predictions on the eval split.\n",
    "df_decile = ltv.decile_stats(y_eval, y_pred)\n",
    "df_decile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DHdLqUqdL4hf"
   },
   "outputs": [],
   "source": [
    "# Mean label next to mean prediction for each bucket.\n",
    "ax = df_decile[[\"label_mean\", \"pred_mean\"]].plot.bar(rot=0)\n",
    "\n",
    "ax.set_title(\"Decile Chart\")\n",
    "ax.set_xlabel(\"Prediction bucket\")\n",
    "ax.set_ylabel(\"Average bucket value\")\n",
    "ax.legend([\"Label\", \"Prediction\"], loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nK6DQ89xU-d4"
   },
   "source": [
    "### Rank Correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "I9qWGyY3WePz"
   },
   "outputs": [],
   "source": [
    "def spearmanr(x1: Sequence[float], x2: Sequence[float]) -> float:\n",
    "    \"\"\"Spearman rank correlation coefficient between two samples.\n",
    "\n",
    "    Wraps `scipy.stats.spearmanr` with `nan_policy=\"raise\"` and returns only\n",
    "    the first element of its result (the correlation).\n",
    "    See https://docs.scipy.org/doc/scipy/reference/stats.html.\n",
    "\n",
    "    Args:\n",
    "      x1: 1D array_like.\n",
    "      x2: 1D array_like.\n",
    "\n",
    "    Returns:\n",
    "      correlation: float.\n",
    "    \"\"\"\n",
    "    correlation, _pvalue = stats.spearmanr(x1, x2, nan_policy=\"raise\")\n",
    "    return correlation\n",
    "\n",
    "\n",
    "spearman_corr = spearmanr(y_eval, y_pred)\n",
    "spearman_corr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-i_AbqhXcurk"
   },
   "source": [
    "### All metrics together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Umqg1-0Bc1HS"
   },
   "outputs": [],
   "source": [
    "# One-row summary of the Keras run, indexed by VERSION. The train/eval losses\n",
    "# are those of the last epoch that ran.\n",
    "df_metrics = pd.DataFrame(\n",
    "    {\n",
    "        \"model\": MODEL,\n",
    "        \"loss_function\": LOSS,\n",
    "        \"train_loss\": history[\"loss\"][-1],\n",
    "        \"eval_loss\": history[\"val_loss\"][-1],\n",
    "        \"label_positive\": np.mean(y_eval > 0),\n",
    "        \"label_mean\": y_eval.mean(),\n",
    "        \"pred_mean\": y_pred.mean(),\n",
    "        \"decile_mape\": df_decile[\"decile_mape\"].mean(),\n",
    "        # Explicit positional access (.iloc): plain [1] / [2] on a Series relies\n",
    "        # on a deprecated positional fallback. Positions are assumed to follow the\n",
    "        # [\"lorenz\", \"baseline\", \"model\"] order passed to gini_from_gain - confirm.\n",
    "        \"baseline_gini\": gini[\"normalized\"].iloc[1],\n",
    "        \"gini\": gini[\"normalized\"].iloc[2],\n",
    "        \"spearman_corr\": spearman_corr,\n",
    "    },\n",
    "    index=[VERSION],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C_cM2Mc2SB3W"
   },
   "outputs": [],
   "source": [
    "# Add one total_profit_<cost x 100> column per unit cost.\n",
    "for unit_cost, total_profit in zip(unit_costs, total_profits):\n",
    "    # round() before int(): a float product such as 0.57 * 100 can land just\n",
    "    # below the intended integer and would otherwise be truncated.\n",
    "    cost_x100 = int(round(unit_cost * 100))\n",
    "    df_metrics[\"total_profit_{:02d}\".format(cost_x100)] = total_profit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iyMvsOtbRrXZ"
   },
   "outputs": [],
   "source": [
    "# Transposed so the single metrics row reads as a column.\n",
    "df_metrics.T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8uHtLKk1x0IE"
   },
   "source": [
    "## Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L-fMkqWIm6X6"
   },
   "outputs": [],
   "source": [
    "# Folder that receives the metrics CSV.\n",
    "output_path = OUTPUT_CSV_FOLDER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jpJJAbWEm94h"
   },
   "outputs": [],
   "source": [
    "# exist_ok=True creates the folder when missing and is a no-op otherwise,\n",
    "# replacing the separate check-then-create steps.\n",
    "os.makedirs(output_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y4LcrTLOm_4B"
   },
   "outputs": [],
   "source": [
    "# File name pattern: <MODEL>_regression_<LOSS>_<VERSION>.csv\n",
    "output_file = os.path.join(\n",
    "    output_path, \"{}_regression_{}_{}.csv\".format(MODEL, LOSS, VERSION)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4WOF7a-dnENp"
   },
   "outputs": [],
   "source": [
    "# index=False: the VERSION index is not written; it is already part of the file name.\n",
    "df_metrics.to_csv(output_file, index=False)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "",
    "kind": "local"
   },
   "name": "regression.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "vscode": {
   "interpreter": {
    "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
